{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jMqk3Z8EciF8"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XbpNOB-vJVKu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bqdaOVRxWs8v"
      },
      "source": [
        "# Wiki Talk Comments Toxicity Prediction"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EG_KEDkodWsT"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/responsible_ai/fairness_indicators/tutorials/Fairness_Indicators_TFCO_Wiki_Case_Study\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/fairness-indicators/blob/master/g3doc/tutorials/Fairness_Indicators_TFCO_Wiki_Case_Study.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/fairness-indicators/tree/master/g3doc/tutorials/Fairness_Indicators_TFCO_Wiki_Case_Study.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/fairness-indicators/g3doc/tutorials/Fairness_Indicators_TFCO_Wiki_Case_Study.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y6T5tlXcdW7J"
      },
      "source": [
        "In this example, we consider the task of predicting whether a discussion comment posted on a Wiki talk page contains toxic content (i.e., content that is “rude, disrespectful or unreasonable”). We use a public \u003ca href=\"https://figshare.com/articles/Wikipedia_Talk_Labels_Toxicity/4563973\"\u003edataset\u003c/a\u003e released by the \u003ca href=\"https://conversationai.github.io/\"\u003eConversation AI\u003c/a\u003e project, which contains over 100k comments from the English Wikipedia annotated by crowd workers (see the [paper](https://arxiv.org/pdf/1610.08914.pdf) for the labeling methodology).\n",
        "\n",
        "One of the challenges with this dataset is that only a very small proportion of the comments cover sensitive topics such as sexuality or religion. As a result, training a neural network model on this dataset can lead to disparate performance on the smaller sensitive topics. For example, innocuous statements about those topics might be incorrectly flagged as ‘toxic’ at higher rates, causing speech to be unfairly censored.\n",
        "\n",
        "By imposing constraints during training, we can train a *fairer* model that performs more equitably across the different topic groups.\n",
        "\n",
        "We will use the TensorFlow Constrained Optimization (TFCO) library to optimize for our fairness goal during training."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DG_C2gsAKV7x"
      },
      "source": [
        "## Installation\n",
        "\n",
        "Let's first install and import the relevant libraries. Note that you may have to restart the Colab runtime once after running the first cell, because the runtime may come with outdated packages. After doing so, there should be no further issues with imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0XOLn8Pyrc_s"
      },
      "outputs": [],
      "source": [
        "#@title pip installs\n",
        "!pip install -q -U pip==20.2\n",
        "\n",
        "!pip install git+https://github.com/google-research/tensorflow_constrained_optimization\n",
        "!pip install git+https://github.com/tensorflow/fairness-indicators"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2ZkQDo2xcDXU"
      },
      "source": [
        "Note that depending on when you run the cell below, you may receive a warning about the default version of TensorFlow in Colab switching to TensorFlow 2.X soon. You can safely ignore that warning as this notebook was designed to be compatible with TensorFlow 1.X and 2.X."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nd_Y6CTnWs8w"
      },
      "outputs": [],
      "source": [
        "#@title Import Modules\n",
        "import io\n",
        "import os\n",
        "import shutil\n",
        "import sys\n",
        "import tempfile\n",
        "import time\n",
        "import urllib.request\n",
        "import zipfile\n",
        "\n",
        "import apache_beam as beam\n",
        "from IPython.display import display\n",
        "from IPython.display import HTML\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow.keras as keras\n",
        "from tensorflow.keras import layers\n",
        "from tensorflow.keras.preprocessing import sequence\n",
        "from tensorflow.keras.preprocessing import text\n",
        "import tensorflow_constrained_optimization as tfco\n",
        "import tensorflow_model_analysis as tfma\n",
        "import fairness_indicators as fi\n",
        "from tensorflow_model_analysis.addons.fairness.view import widget_view\n",
        "from tensorflow_model_analysis.model_agnostic_eval import model_agnostic_evaluate_graph\n",
        "from tensorflow_model_analysis.model_agnostic_eval import model_agnostic_extractor\n",
        "from tensorflow_model_analysis.model_agnostic_eval import model_agnostic_predict as agnostic_predict"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GvqR564dLEVa"
      },
      "source": [
        "Though TFCO is compatible with both eager and graph execution, this notebook assumes that eager execution is enabled, as it is by default in TensorFlow 2.X. To ensure that nothing breaks, the cell below explicitly enables eager execution when running under TensorFlow 1.X."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "avMBqzjWct4Z"
      },
      "outputs": [],
      "source": [
        "#@title Enable Eager Execution and Print Versions\n",
        "if tf.__version__ \u003c \"2.0.0\":\n",
        "  tf.enable_eager_execution()\n",
        "  print(\"Eager execution enabled.\")\n",
        "else:\n",
        "  print(\"Eager execution enabled by default.\")\n",
        "\n",
        "print(\"TensorFlow \" + tf.__version__)\n",
        "print(\"TFMA \" + tfma.__version__)\n",
        "print(\"FI \" + fi.version.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YUJyWaAwWs83"
      },
      "source": [
        "## Hyper-parameters\n",
        "\n",
        "First, we set some hyper-parameters needed for the data preprocessing and model training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1aXlwlqTWs84"
      },
      "outputs": [],
      "source": [
        "hparams = {\n",
        "    \"batch_size\": 128,\n",
        "    \"cnn_filter_sizes\": [128, 128, 128],\n",
        "    \"cnn_kernel_sizes\": [5, 5, 5],\n",
        "    \"cnn_pooling_sizes\": [5, 5, 40],\n",
        "    \"constraint_learning_rate\": 0.01,\n",
        "    \"embedding_dim\": 100,\n",
        "    \"embedding_trainable\": False,\n",
        "    \"learning_rate\": 0.005,\n",
        "    \"max_num_words\": 10000,\n",
        "    \"max_sequence_length\": 250\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0PMs8Iwxq98C"
      },
      "source": [
        "## Load and pre-process dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DIe2JRDeWs87"
      },
      "source": [
        "Next, we download the dataset and preprocess it. The train, test and validation sets are provided as separate CSV files."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rcd2CV7pWs88"
      },
      "outputs": [],
      "source": [
        "toxicity_data_url = (\"https://raw.githubusercontent.com/conversationai/\"\n",
        "                     \"unintended-ml-bias-analysis/master/data/\")\n",
        "\n",
        "data_train = pd.read_csv(toxicity_data_url + \"wiki_train.csv\")\n",
        "data_test = pd.read_csv(toxicity_data_url + \"wiki_test.csv\")\n",
        "data_vali = pd.read_csv(toxicity_data_url + \"wiki_dev.csv\")\n",
        "\n",
        "data_train.head()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ojo617RIWs8_"
      },
      "source": [
        "The `comment` column contains the discussion comments, and the `is_toxic` column indicates whether or not a comment is annotated as toxic.\n",
        "\n",
        "In the following, we:\n",
        "1. Separate out the labels\n",
        "2. Tokenize the text comments\n",
        "3. Identify comments that contain sensitive topic terms \n",
        "\n",
        "First, we separate the labels from the train, test and validation sets. The labels are binary (0 or 1); we store them as floating-point column vectors of shape `(n, 1)`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mxo7ny90Ws9A"
      },
      "outputs": [],
      "source": [
        "labels_train = data_train[\"is_toxic\"].values.reshape(-1, 1) * 1.0\n",
        "labels_test = data_test[\"is_toxic\"].values.reshape(-1, 1) * 1.0\n",
        "labels_vali = data_vali[\"is_toxic\"].values.reshape(-1, 1) * 1.0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "alrWi6jUWs9C"
      },
      "source": [
        "Next, we tokenize the textual comments using the `Tokenizer` provided by `Keras`. We build the vocabulary of tokens from the training-set comments alone, and then use it to convert all the comments into (padded) token sequences of the same length."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yvOTBsrHWs9D"
      },
      "outputs": [],
      "source": [
        "tokenizer = text.Tokenizer(num_words=hparams[\"max_num_words\"])\n",
        "tokenizer.fit_on_texts(data_train[\"comment\"])\n",
        "\n",
        "def prep_text(texts, tokenizer, max_sequence_length):\n",
        "    # Turns texts into padded sequences of token indices.\n",
        "    text_sequences = tokenizer.texts_to_sequences(texts)\n",
        "    return sequence.pad_sequences(text_sequences, maxlen=max_sequence_length)\n",
        "\n",
        "text_train = prep_text(data_train[\"comment\"], tokenizer, hparams[\"max_sequence_length\"])\n",
        "text_test = prep_text(data_test[\"comment\"], tokenizer, hparams[\"max_sequence_length\"])\n",
        "text_vali = prep_text(data_vali[\"comment\"], tokenizer, hparams[\"max_sequence_length\"])"
      ]
    },
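    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two preprocessing steps above more concrete, the next cell is a simplified, pure-Python sketch of what they do. It is *not* the Keras implementation: it assumes plain whitespace tokenization and ignores details such as the `Tokenizer`'s punctuation filtering. Words are ranked by frequency (index 1 is the most frequent word, ties keep first-appearance order, and index 0 is reserved for padding), words whose index is not below `num_words` are dropped, and each sequence is then truncated and padded at the front, matching the `'pre'` defaults of `pad_sequences`. The helper names `build_vocab` and `texts_to_padded` are hypothetical. With `num_words=4`, only the three most frequent toy words survive, so \"dog\", \"on\" and \"mat\" disappear from the padded output."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Simplified sketch of Tokenizer + pad_sequences (hypothetical helpers, not Keras code).\n",
        "def build_vocab(texts):\n",
        "    # Count lowercase, whitespace-separated words in order of first appearance.\n",
        "    counts = {}\n",
        "    for t in texts:\n",
        "        for w in t.lower().split():\n",
        "            counts[w] = counts.get(w, 0) + 1\n",
        "    # Rank by frequency; Python's sort is stable, so ties keep first-appearance order.\n",
        "    ranked = sorted(counts, key=lambda w: counts[w], reverse=True)\n",
        "    # Index 0 is reserved for padding, so real words start at 1.\n",
        "    return {w: i + 1 for i, w in enumerate(ranked)}\n",
        "\n",
        "def texts_to_padded(texts, word_index, num_words, maxlen):\n",
        "    padded = []\n",
        "    for t in texts:\n",
        "        # Keep only known words whose index is below num_words.\n",
        "        seq = [word_index[w] for w in t.lower().split()\n",
        "               if w in word_index and word_index[w] \u003c num_words]\n",
        "        seq = seq[-maxlen:]  # Truncate at the front ('pre').\n",
        "        padded.append([0] * (maxlen - len(seq)) + seq)  # Pad at the front ('pre').\n",
        "    return padded\n",
        "\n",
        "toy_word_index = build_vocab([\"the cat sat\", \"the dog sat on the mat\"])\n",
        "print(toy_word_index)  # {'the': 1, 'sat': 2, 'cat': 3, 'dog': 4, 'on': 5, 'mat': 6}\n",
        "print(texts_to_padded([\"the dog sat\", \"the cat sat on the mat\"],\n",
        "                      toy_word_index, num_words=4, maxlen=4))  # [[0, 0, 1, 2], [1, 3, 2, 1]]"
      ]
    },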
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cn5zbgp-Ws9F"
      },
      "source": [
        "Finally, we identify comments related to certain sensitive topic groups. We consider a subset of the \u003ca href=\"https://github.com/conversationai/unintended-ml-bias-analysis/blob/master/unintended_ml_bias/bias_madlibs_data/adjectives_people.txt\"\u003eidentity terms\u003c/a\u003e provided with the dataset and group them into\n",
        "four broad topic groups: *sexuality*, *gender identity*, *religion*, and *race*."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EnFfV2gEWs9G"
      },
      "outputs": [],
      "source": [
        "terms = {\n",
        "    'sexuality': ['gay', 'lesbian', 'bisexual', 'homosexual', 'straight', 'heterosexual'], \n",
        "    'gender identity': ['trans', 'transgender', 'cis', 'nonbinary'],\n",
        "    'religion': ['christian', 'muslim', 'jewish', 'buddhist', 'catholic', 'protestant', 'sikh', 'taoist'],\n",
        "    'race': ['african', 'african american', 'black', 'white', 'european', 'hispanic', 'latino', 'latina', \n",
        "             'latinx', 'mexican', 'canadian', 'american', 'asian', 'indian', 'middle eastern', 'chinese', \n",
        "             'japanese']}\n",
        "\n",
        "group_names = list(terms.keys())\n",
        "num_groups = len(group_names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ooI3F5M4Ws9I"
      },
      "source": [
        "We then create separate group membership matrices for the train, test and validation sets, where the rows correspond to comments, the columns correspond to the four sensitive groups, and each entry is a boolean indicating whether the comment contains a term from the topic group."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zO7PyNckWs9J"
      },
      "outputs": [],
      "source": [
        "def get_groups(text):\n",
        "    # Returns a NumPy array of shape (n, k) with 0/1 entries, where n is the number\n",
        "    # of comments and k is the number of groups. Entry (i, j) is 1 if the i-th\n",
        "    # comment contains a term from the j-th group. Matching is case-insensitive and\n",
        "    # substring-based (no word boundaries).\n",
        "    groups = np.zeros((text.shape[0], num_groups))\n",
        "    for ii in range(num_groups):\n",
        "        groups[:, ii] = text.str.contains('|'.join(terms[group_names[ii]]), case=False)\n",
        "    return groups\n",
        "\n",
        "groups_train = get_groups(data_train[\"comment\"])\n",
        "groups_test = get_groups(data_test[\"comment\"])\n",
        "groups_vali = get_groups(data_vali[\"comment\"])"
      ]
    },
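    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note that `str.contains` treats the joined terms as a regular expression and matches them anywhere in a comment, without word boundaries, so a short term such as `'cis'` also matches a word like \"decision\". The next cell is a hedged sketch contrasting this substring matching with one possible alternative, whole-word matching using `\\b` anchors, on two made-up comments. The helper names are hypothetical, and the rest of this notebook keeps the original substring matching."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "def substring_match(comments, words):\n",
        "    # Mirrors get_groups: a term matches anywhere, even inside a longer word.\n",
        "    pattern = re.compile(\"|\".join(words), re.IGNORECASE)\n",
        "    return [bool(pattern.search(c)) for c in comments]\n",
        "\n",
        "def whole_word_match(comments, words):\n",
        "    # Hypothetical alternative: a term must be delimited by word boundaries.\n",
        "    pattern = re.compile(r\"\\b(?:\" + \"|\".join(map(re.escape, words)) + r\")\\b\",\n",
        "                         re.IGNORECASE)\n",
        "    return [bool(pattern.search(c)) for c in comments]\n",
        "\n",
        "toy_comments = [\"That decision was fine.\", \"As a cis woman, I agree.\"]\n",
        "toy_terms = [\"trans\", \"transgender\", \"cis\", \"nonbinary\"]\n",
        "print(substring_match(toy_comments, toy_terms))   # [True, True]\n",
        "print(whole_word_match(toy_comments, toy_terms))  # [False, True]"
      ]
    },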
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GFAI6AB9Ws9L"
      },
      "source": [
        "As shown below, all four topic groups constitute only a small fraction of the overall dataset, and have varying proportions of toxic comments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Ug4u_P9Ws9M"
      },
      "outputs": [],
      "source": [
        "print(\"Overall label proportion = %.1f%%\" % (labels_train.mean() * 100))\n",
        "\n",
        "group_stats = []\n",
        "for ii in range(num_groups):\n",
        "    group_proportion = groups_train[:, ii].mean()\n",
        "    group_pos_proportion = labels_train[groups_train[:, ii] == 1].mean()\n",
        "    group_stats.append([group_names[ii],\n",
        "                        \"%.2f%%\" % (group_proportion * 100), \n",
        "                        \"%.1f%%\" % (group_pos_proportion * 100)])\n",
        "group_stats = pd.DataFrame(group_stats, \n",
        "                           columns=[\"Topic group\", \"Group proportion\", \"Label proportion\"])\n",
        "group_stats"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aG5ZKKrVWs9O"
      },
      "source": [
        "We see that only 1.3% of the dataset contains comments related to sexuality. Among them, 37% of the comments have been annotated as toxic. Note that this is significantly larger than the overall proportion of comments annotated as toxic. This could be because the few comments that used those identity terms did so in pejorative contexts. As mentioned above, this could cause our model to disproportionately misclassify comments as toxic when they include those terms. Since this is our main concern, we'll make sure to look at the **False Positive Rate** when we evaluate the model's performance."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5DkJpKaLWs9P"
      },
      "source": [
        "## Build CNN toxicity prediction model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "niJ4KIJgWs9Q"
      },
      "source": [
        "Having prepared the dataset, we now build a `Keras` model for predicting toxicity. The model we use is a convolutional neural network (CNN) with the same architecture used by the Conversation AI project for their debiasing analysis. We adapt \u003ca href=\"https://github.com/conversationai/unintended-ml-bias-analysis/blob/master/unintended_ml_bias/model_tool.py\"\u003ecode\u003c/a\u003e provided by them to construct the model layers.\n",
        "\n",
        "The model first uses an embedding layer to map each text token to a fixed-length vector, turning the input token sequence into a sequence of vectors. It then passes this sequence through several convolution and pooling layers, followed by a flatten layer, a fully-connected hidden layer, and a single-unit output layer.\n",
        "\n",
        "We make use of pre-trained GloVe word vector embeddings, which we download below. This may take a few minutes to complete."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yevbBL2oWs9Q"
      },
      "outputs": [],
      "source": [
        "zip_file_url = \"http://nlp.stanford.edu/data/glove.6B.zip\"\n",
        "zip_file = urllib.request.urlopen(zip_file_url)\n",
        "archive = zipfile.ZipFile(io.BytesIO(zip_file.read()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a11-YWDnWs9S"
      },
      "source": [
        "We use the downloaded GloVe embeddings to create an embedding matrix, where the rows contain the word embeddings for the tokens in the `Tokenizer`'s vocabulary. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bBS74MMYWs9T"
      },
      "outputs": [],
      "source": [
        "embeddings_index = {}\n",
        "glove_file = \"glove.6B.100d.txt\"\n",
        "\n",
        "with archive.open(glove_file) as f:\n",
        "    for line in f:\n",
        "        values = line.split()\n",
        "        word = values[0].decode(\"utf-8\") \n",
        "        coefs = np.asarray(values[1:], dtype=\"float32\")\n",
        "        embeddings_index[word] = coefs\n",
        "\n",
        "embedding_matrix = np.zeros((len(tokenizer.word_index) + 1, hparams[\"embedding_dim\"]))\n",
        "num_words_in_embedding = 0\n",
        "for word, i in tokenizer.word_index.items():\n",
        "    embedding_vector = embeddings_index.get(word)\n",
        "    if embedding_vector is not None:\n",
        "        num_words_in_embedding += 1\n",
        "        embedding_matrix[i] = embedding_vector"
      ]
    },
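    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the layout of `embedding_matrix` on something small, the next cell replays the same parsing and filling logic on a tiny in-memory stand-in for the GloVe file. The two vectors and the three-word vocabulary are made up. Row 0 stays all zeros because the `Tokenizer` never assigns index 0 (it is reserved for padding), and so does the row of any word without a GloVe vector. The same ratio computed from the real data, `num_words_in_embedding / len(tokenizer.word_index)`, is one way to check how much of the vocabulary GloVe covers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import io\n",
        "import numpy as np\n",
        "\n",
        "# Tiny stand-in for glove.6B.100d.txt: each line is a token followed by its vector.\n",
        "fake_glove = io.BytesIO(b\"the 0.1 0.2 0.3\\ncat 0.4 0.5 0.6\\n\")\n",
        "toy_embeddings_index = {}\n",
        "for line in fake_glove:\n",
        "    values = line.split()\n",
        "    toy_embeddings_index[values[0].decode(\"utf-8\")] = np.asarray(values[1:], dtype=\"float32\")\n",
        "\n",
        "# Hypothetical vocabulary; \"zebra\" has no vector in the toy file.\n",
        "toy_word_index = {\"the\": 1, \"cat\": 2, \"zebra\": 3}\n",
        "toy_matrix = np.zeros((len(toy_word_index) + 1, 3))  # Row 0 is the padding index.\n",
        "toy_found = 0\n",
        "for word, i in toy_word_index.items():\n",
        "    vector = toy_embeddings_index.get(word)\n",
        "    if vector is not None:\n",
        "        toy_found += 1\n",
        "        toy_matrix[i] = vector\n",
        "\n",
        "print(toy_matrix)\n",
        "print(\"Found %d of %d words (%.0f%%)\" % (toy_found, len(toy_word_index),\n",
        "                                         100.0 * toy_found / len(toy_word_index)))"
      ]
    },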
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t9NVp-_eWs9V"
      },
      "source": [
        "We are now ready to specify the `Keras` layers. We write a function to create a new model, which we will invoke whenever we wish to train a new model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_f_DhA6OWs9W"
      },
      "outputs": [],
      "source": [
        "def create_model():\n",
        "    model = keras.Sequential()\n",
        "\n",
        "    # Embedding layer.\n",
        "    embedding_layer = layers.Embedding(\n",
        "        embedding_matrix.shape[0],\n",
        "        embedding_matrix.shape[1],\n",
        "        weights=[embedding_matrix],\n",
        "        input_length=hparams[\"max_sequence_length\"],\n",
        "        trainable=hparams['embedding_trainable'])\n",
        "    model.add(embedding_layer)\n",
        "\n",
        "    # Convolution layers.\n",
        "    for filter_size, kernel_size, pool_size in zip(\n",
        "        hparams['cnn_filter_sizes'], hparams['cnn_kernel_sizes'],\n",
        "        hparams['cnn_pooling_sizes']):\n",
        "\n",
        "        conv_layer = layers.Conv1D(\n",
        "            filter_size, kernel_size, activation='relu', padding='same')\n",
        "        model.add(conv_layer)\n",
        "\n",
        "        pooled_layer = layers.MaxPooling1D(pool_size, padding='same')\n",
        "        model.add(pooled_layer)\n",
        "\n",
        "    # Add a flatten layer, a fully-connected layer and an output layer.\n",
        "    model.add(layers.Flatten())\n",
        "    model.add(layers.Dense(128, activation='relu'))\n",
        "    model.add(layers.Dense(1))\n",
        "    \n",
        "    return model"
      ]
    },
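    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check on the architecture, the next cell works out the sequence length after the last pooling layer, the size of the flattened vector, and the number of trainable parameters by hand, without building the model. It assumes that `Conv1D` with `padding='same'` and the default stride of 1 keeps the sequence length, and that `MaxPooling1D` with `padding='same'` and its default stride (equal to the pool size) produces `ceil(length / pool_size)` steps. With the hyper-parameters above, the length therefore shrinks from 250 to 50, then 10, and finally 1. The embedding layer is left out because `embedding_trainable` is `False`. If these assumptions hold, the total should agree with the trainable-parameter count that Keras reports for the model. The helper name is hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def cnn_trainable_params(seq_len, embedding_dim, filter_sizes, kernel_sizes,\n",
        "                         pool_sizes, hidden_units=128):\n",
        "    # Hand count for the architecture in create_model(), excluding the frozen embedding.\n",
        "    length, channels, params = seq_len, embedding_dim, 0\n",
        "    for filters, kernel, pool in zip(filter_sizes, kernel_sizes, pool_sizes):\n",
        "        params += kernel * channels * filters + filters  # Conv1D kernel + bias.\n",
        "        channels = filters  # 'same' padding keeps the length.\n",
        "        length = math.ceil(length / pool)  # MaxPooling1D with 'same' padding.\n",
        "    flat = length * channels  # Flatten.\n",
        "    params += flat * hidden_units + hidden_units  # Dense(128, relu).\n",
        "    params += hidden_units + 1  # Dense(1) producing the logit.\n",
        "    return length, flat, params\n",
        "\n",
        "# Values copied from the hparams dictionary defined earlier.\n",
        "print(cnn_trainable_params(250, 100, [128, 128, 128], [5, 5, 5], [5, 5, 40]))\n",
        "# (1, 128, 244865)"
      ]
    },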
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CwcqYITBN7bW"
      },
      "source": [
        "We also define a function that sets the NumPy and TensorFlow random seeds, to make the results reproducible."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C_1nsXntN98C"
      },
      "outputs": [],
      "source": [
        "def set_seeds():\n",
        "  np.random.seed(121212)\n",
        "  tf.compat.v1.set_random_seed(212121)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X-_fKjDtWs9Y"
      },
      "source": [
        "## Fairness indicators"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k009haGaWs9Z"
      },
      "source": [
        "We also write functions to compute and plot fairness indicators."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9ZgGCAs8V-I"
      },
      "outputs": [],
      "source": [
        "def create_examples(labels, predictions, groups, group_names):\n",
        "  # Returns tf.examples with given labels, predictions, and group information.  \n",
        "  examples = []\n",
        "  sigmoid = lambda x: 1/(1 + np.exp(-x)) \n",
        "  for ii in range(labels.shape[0]):\n",
        "    example = tf.train.Example()\n",
        "    example.features.feature['toxicity'].float_list.value.append(\n",
        "        labels[ii])\n",
        "    example.features.feature['prediction'].float_list.value.append(\n",
        "        sigmoid(predictions[ii]))  # predictions need to be in [0, 1].\n",
        "    for jj in range(groups.shape[1]):\n",
        "      example.features.feature[group_names[jj]].bytes_list.value.append(\n",
        "          b'Yes' if groups[ii, jj] else b'No')\n",
        "    examples.append(example)\n",
        "  return examples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vESL-3dU9iiG"
      },
      "outputs": [],
      "source": [
        "def evaluate_results(labels, predictions, groups, group_names):\n",
        "  # Evaluates fairness indicators for given labels, predictions and group\n",
        "  # membership info.\n",
        "  examples = create_examples(labels, predictions, groups, group_names)\n",
        "\n",
        "  # Create feature map for labels, predictions and each group.\n",
        "  feature_map = {\n",
        "      'prediction': tf.io.FixedLenFeature([], tf.float32),\n",
        "      'toxicity': tf.io.FixedLenFeature([], tf.float32),\n",
        "  }\n",
        "  for group in group_names:\n",
        "    feature_map[group] = tf.io.FixedLenFeature([], tf.string)\n",
        "\n",
        "  # Serialize the examples.\n",
        "  serialized_examples = [e.SerializeToString() for e in examples]\n",
        "\n",
        "  BASE_DIR = tempfile.gettempdir()\n",
        "  OUTPUT_DIR = os.path.join(BASE_DIR, 'output')\n",
        "\n",
        "  with beam.Pipeline() as pipeline:\n",
        "    model_agnostic_config = agnostic_predict.ModelAgnosticConfig(\n",
        "              label_keys=['toxicity'],\n",
        "              prediction_keys=['prediction'],\n",
        "              feature_spec=feature_map)\n",
        "    \n",
        "    slices = [tfma.slicer.SingleSliceSpec()]\n",
        "    for group in group_names:\n",
        "      slices.append(\n",
        "          tfma.slicer.SingleSliceSpec(columns=[group]))\n",
        "\n",
        "    extractors = [\n",
        "            model_agnostic_extractor.ModelAgnosticExtractor(\n",
        "                model_agnostic_config=model_agnostic_config),\n",
        "            tfma.extractors.slice_key_extractor.SliceKeyExtractor(slices)\n",
        "        ]\n",
        "\n",
        "    metrics_callbacks = [\n",
        "      tfma.post_export_metrics.fairness_indicators(\n",
        "          thresholds=[0.5],\n",
        "          target_prediction_keys=['prediction'],\n",
        "          labels_key='toxicity'),\n",
        "      tfma.post_export_metrics.example_count()]\n",
        "\n",
        "    # Create a model agnostic aggregator.\n",
        "    eval_shared_model = tfma.types.EvalSharedModel(\n",
        "        add_metrics_callbacks=metrics_callbacks,\n",
        "        construct_fn=model_agnostic_evaluate_graph.make_construct_fn(\n",
        "            add_metrics_callbacks=metrics_callbacks,\n",
        "            config=model_agnostic_config))\n",
        "\n",
        "    # Run Model Agnostic Eval.\n",
        "    _ = (\n",
        "        pipeline\n",
        "        | beam.Create(serialized_examples)\n",
        "        | 'ExtractEvaluateAndWriteResults' \u003e\u003e\n",
        "          tfma.ExtractEvaluateAndWriteResults(\n",
        "              eval_shared_model=eval_shared_model,\n",
        "              output_path=OUTPUT_DIR,\n",
        "              extractors=extractors,\n",
        "              compute_confidence_intervals=True\n",
        "          )\n",
        "    )\n",
        "\n",
        "  fairness_ind_result = tfma.load_eval_result(output_path=OUTPUT_DIR)\n",
        "\n",
        "  # Also evaluate accuracy of the model.\n",
        "  accuracy = np.mean(labels == (predictions \u003e 0.0))\n",
        "\n",
        "  return fairness_ind_result, accuracy"
      ]
    },
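    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`create_examples` passes the model's raw outputs (logits) through a sigmoid before handing them to Fairness Indicators, while `evaluate_results` computes accuracy directly from the logits with `predictions \u003e 0.0`. The two agree because the sigmoid is increasing and equals 0.5 at zero, so a logit above 0 corresponds exactly to a probability above the 0.5 threshold used for the fairness metrics. The small check below illustrates this on made-up logits and labels; it is a standalone illustration, not part of the evaluation pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "def sigmoid(x):\n",
        "    return 1 / (1 + math.exp(-x))\n",
        "\n",
        "# Made-up logits and labels.\n",
        "toy_logits = [-2.0, 0.0, 0.5, 3.0]\n",
        "toy_labels = [1, 0, 0, 1]\n",
        "\n",
        "# Thresholding probabilities at 0.5 gives the same decisions as thresholding logits at 0.\n",
        "toy_decisions = [sigmoid(z) \u003e 0.5 for z in toy_logits]\n",
        "assert toy_decisions == [z \u003e 0.0 for z in toy_logits]\n",
        "\n",
        "toy_accuracy = sum(int(d) == y for d, y in zip(toy_decisions, toy_labels)) / len(toy_labels)\n",
        "print(toy_decisions, toy_accuracy)  # [False, False, True, True] 0.5"
      ]
    },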
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W3Sp7mpsWs9f"
      },
      "outputs": [],
      "source": [
        "def plot_fairness_indicators(eval_result, title):\n",
        "  fairness_ind_result, accuracy = eval_result\n",
        "  display(HTML(\"\u003ccenter\u003e\u003ch2\u003e\" + title + \n",
        "               \" (Accuracy = %.2f%%)\" % (accuracy * 100) + \"\u003c/h2\u003e\u003c/center\u003e\"))\n",
        "  widget_view.render_fairness_indicator(fairness_ind_result)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WqLdtgI42fxb"
      },
      "outputs": [],
      "source": [
        "def plot_multi_fairness_indicators(multi_eval_results):\n",
        " \n",
        "  multi_results = {}\n",
        "  multi_accuracy = {}\n",
        "  for title, (fairness_ind_result, accuracy) in multi_eval_results.items():\n",
        "    multi_results[title] = fairness_ind_result\n",
        "    multi_accuracy[title] = accuracy\n",
        "  \n",
        "  title_str = \"; \".join(\n",
        "      title + \" (Accuracy = %.2f%%)\" % (multi_accuracy[title] * 100)\n",
        "      for title in multi_eval_results.keys())\n",
        "  display(HTML(\"\u003ccenter\u003e\u003ch2\u003e\" + title_str + \"\u003c/h2\u003e\u003c/center\u003e\"))\n",
        "  widget_view.render_fairness_indicator(multi_eval_results=multi_results)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8aWNc4CdWs9h"
      },
      "source": [
        "## Train unconstrained model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DuSA8qL7Ws9i"
      },
      "source": [
        "For the first model, we optimize a simple cross-entropy loss *without* any constraints."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0g50bauHWs9j"
      },
      "outputs": [],
      "source": [
        "# Set random seed for reproducible results.\n",
        "set_seeds()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YsCoHMG_iIzc"
      },
      "source": [
        "**Note**: The following code cell can take ~8 minutes to run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tamJiG3FiDYW"
      },
      "outputs": [],
      "source": [
        "# Optimizer and loss.\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=hparams[\"learning_rate\"])\n",
        "loss = lambda y_true, y_pred: tf.keras.losses.binary_crossentropy(\n",
        "    y_true, y_pred, from_logits=True)\n",
        "\n",
        "# Create, compile and fit model.\n",
        "model_unconstrained = create_model()\n",
        "model_unconstrained.compile(optimizer=optimizer, loss=loss)\n",
        "\n",
        "model_unconstrained.fit(\n",
        "    x=text_train, y=labels_train, batch_size=hparams[\"batch_size\"], epochs=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "p7AvIdktWs9t"
      },
      "source": [
        "Having trained the unconstrained model, we compute various evaluation metrics for the model on the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tHV40_21lRL6"
      },
      "outputs": [],
      "source": [
        "scores_unconstrained_test = model_unconstrained.predict(text_test)\n",
        "eval_result_unconstrained = evaluate_results(\n",
        "    labels_test, scores_unconstrained_test, groups_test, group_names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AJpRuN0EOeyG"
      },
      "source": [
        "As explained above, we are concentrating on the false positive rate. In the current version (0.1.2), Fairness Indicators selects the false negative rate by default. After running the cell below, deselect `false_negative_rate` and select `false_positive_rate` to view the metric we are interested in."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2fwNpfou4yvP"
      },
      "outputs": [],
      "source": [
        "plot_fairness_indicators(eval_result_unconstrained, \"Unconstrained\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J3TbAenkGM7P"
      },
      "source": [
        "While the overall false positive rate is less than 2%, the false positive rate on sexuality-related comments is significantly higher. This is because the sexuality group is very small and has a disproportionately high fraction of comments annotated as toxic. As a result, a model trained without constraints learns to treat sexuality-related terms as a strong indicator of toxicity."
      ]
    },
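    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the metric concrete, the following is a minimal NumPy sketch of how an overall and a per-group false positive rate can be computed from binary labels, binary predictions and a group-membership matrix. The arrays are tiny synthetic values chosen for illustration, not data from the Wiki Talk dataset, and the helper `false_positive_rate` is hypothetical; in this notebook, Fairness Indicators computes these metrics for you.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def false_positive_rate(labels, preds, mask=None):\n",
        "    # Hypothetical helper: FPR = FP / (FP + TN), i.e. the fraction of\n",
        "    # non-toxic examples (label 0) that are predicted toxic (prediction 1).\n",
        "    labels = np.asarray(labels).ravel()\n",
        "    preds = np.asarray(preds).ravel()\n",
        "    if mask is None:\n",
        "        mask = np.ones_like(labels, dtype=bool)\n",
        "    negatives = mask \u0026 (labels == 0)\n",
        "    if not negatives.any():\n",
        "        return float('nan')  # FPR is undefined without negative examples.\n",
        "    return float(preds[negatives].mean())\n",
        "\n",
        "# Toy synthetic data: 8 comments and 2 groups (one column per group).\n",
        "labels = np.array([0, 0, 0, 0, 0, 1, 0, 1])\n",
        "preds = np.array([0, 1, 0, 0, 1, 1, 0, 1])\n",
        "groups = np.array([[1, 0], [1, 0], [0, 0], [0, 1],\n",
        "                   [1, 0], [1, 0], [0, 1], [0, 0]])\n",
        "\n",
        "overall_fpr = false_positive_rate(labels, preds)\n",
        "per_group_fpr = [false_positive_rate(labels, preds, groups[:, g] \u003e 0)\n",
        "                 for g in range(groups.shape[1])]\n",
        "print(overall_fpr)    # 0.3333333333333333\n",
        "print(per_group_fpr)  # [0.6666666666666666, 0.0]\n",
        "```\n",
        "\n",
        "In this toy example, the first group's false positive rate is twice the overall rate, which is the kind of gap described above for sexuality-related comments."
      ]
    },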
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KmxyAo9hWs9w"
      },
      "source": [
        "## Train with constraints on false positive rates"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l3dYUchIWs9w"
      },
      "source": [
        "To avoid large differences in false positive rates across groups, we next train a model in which each group's false positive rate is constrained to stay within a desired limit. Specifically, we optimize the error rate of the model subject to the *per-group false positive rates being less than or equal to 2%*.\n",
        "\n",
        "Training on minibatches with per-group constraints can be challenging for this dataset, however: the groups we wish to constrain are all small, so individual minibatches are likely to contain very few examples from each group. The gradients computed during training are therefore noisy, which makes the model converge very slowly.\n",
        "\n",
        "To mitigate this problem, we recommend using two streams of minibatches: the first is formed, as before, from the entire training set, and the second is formed solely from the sensitive-group examples. We compute the objective using minibatches from the first stream and the per-group constraints using minibatches from the second stream. Because batches from the second stream are likely to contain many more examples from each group, we expect the resulting updates to be less noisy.\n",
        "\n",
        "We create separate feature and label tensors to hold the minibatches from each of the two streams, along with a groups tensor for the second stream."
      ]
    },
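    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A back-of-the-envelope calculation illustrates why the first stream alone gives noisy constraint estimates. Treating each slot in a minibatch of size `B` as an independent draw from a training set in which a group makes up a fraction `p` of the examples, a minibatch contains on average `B * p` group examples, and with probability `(1 - p) ** B` it contains none at all. The sketch below uses hypothetical values `p = 0.01` and `B = 128`; these are illustrative assumptions, not the group sizes or batch size used in this notebook.\n",
        "\n",
        "```python\n",
        "def group_count_stats(group_fraction, batch_size):\n",
        "    # Assumes each minibatch slot is an independent draw from the training set.\n",
        "    # Expected number of group examples in one minibatch.\n",
        "    expected = batch_size * group_fraction\n",
        "    # Probability that a minibatch contains no group examples at all.\n",
        "    prob_empty = (1.0 - group_fraction) ** batch_size\n",
        "    return expected, prob_empty\n",
        "\n",
        "# Hypothetical values, for illustration only.\n",
        "expected, prob_empty = group_count_stats(group_fraction=0.01, batch_size=128)\n",
        "print('Expected group examples per minibatch: %.2f' % expected)  # 1.28\n",
        "print('Probability of a minibatch with none: %.3f' % prob_empty)  # 0.276\n",
        "```\n",
        "\n",
        "By contrast, every example in the second stream belongs to at least one sensitive group, so its minibatches supply many more examples for estimating the per-group constraints."
      ]
    },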
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vMuuTOEOWs9x"
      },
      "outputs": [],
      "source": [
        "# Set random seed.\n",
        "set_seeds()\n",
        "\n",
        "# Features tensors.\n",
        "batch_shape = (hparams[\"batch_size\"], hparams['max_sequence_length'])\n",
        "features_tensor = tf.Variable(np.zeros(batch_shape, dtype='int32'), name='x')\n",
        "features_tensor_sen = tf.Variable(np.zeros(batch_shape, dtype='int32'), name='x_sen')\n",
        "\n",
        "# Labels tensors.\n",
        "batch_shape = (hparams[\"batch_size\"], 1)\n",
        "labels_tensor = tf.Variable(np.zeros(batch_shape, dtype='float32'), name='labels')\n",
        "labels_tensor_sen = tf.Variable(np.zeros(batch_shape, dtype='float32'), name='labels_sen')\n",
        "\n",
        "# Groups tensors.\n",
        "batch_shape = (hparams[\"batch_size\"], num_groups)\n",
        "groups_tensor_sen = tf.Variable(np.zeros(batch_shape, dtype='float32'), name='groups_sen')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-wh26V7nWs9z"
      },
      "source": [
        "We instantiate a new model and define functions that compute its predictions on the minibatches from each of the two streams."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kawyrkQIWs9z"
      },
      "outputs": [],
      "source": [
        "# Create model, and separate prediction functions for the two streams. \n",
        "# For the predictions, we use a nullary function returning a Tensor to support eager mode.\n",
        "model_constrained = create_model()\n",
        "\n",
        "def predictions():\n",
        "  return model_constrained(features_tensor)\n",
        "\n",
        "def predictions_sen():\n",
        "  return model_constrained(features_tensor_sen)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UG9t7dw1Ws91"
      },
      "source": [
        "We then set up a constrained optimization problem with the error rate as the objective and with constraints on the per-group false positive rate."
      ]
    },
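    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Written out, the problem set up in the next cell is, roughly, the following, where $\\theta$ denotes the model parameters, $\\mathrm{ER}$ the error rate computed on the first stream, $\\mathrm{FPR}_g$ the false positive rate on examples from group $g$ computed on the second stream, and $\\epsilon = 0.02$. As the code comments note, the constraint is imposed on the sexuality group, i.e. the examples with a positive entry in column 0 of the groups tensor.\n",
        "\n",
        "$$\n",
        "\\min_{\\theta} \\; \\mathrm{ER}(\\theta) \\quad \\text{subject to} \\quad \\mathrm{FPR}_{\\text{sexuality}}(\\theta) \\le \\epsilon,\n",
        "$$\n",
        "\n",
        "$$\n",
        "\\mathrm{FPR}_{g}(\\theta) = \\Pr\\left(\\hat{y}_{\\theta}(x) = 1 \\mid y = 0,\\ x \\in g\\right).\n",
        "$$\n",
        "\n",
        "Because these rates are piecewise constant in $\\theta$, TFCO trains the model on differentiable hinge-loss proxies of them (the training loop below reports the objective as a hinge loss), with `tfco.ProxyLagrangianOptimizerV2` handling the constraint."
      ]
    },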
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EhKAMGSJWs93"
      },
      "outputs": [],
      "source": [
        "epsilon = 0.02  # Desired false-positive rate threshold.\n",
        "\n",
        "# Set up separate contexts for the two minibatch streams.\n",
        "context = tfco.rate_context(predictions, lambda:labels_tensor)\n",
        "context_sen = tfco.rate_context(predictions_sen, lambda:labels_tensor_sen)\n",
        "\n",
        "# Compute the objective using the first stream.\n",
        "objective = tfco.error_rate(context)\n",
        "\n",
        "# Compute the constraint using the second stream.\n",
        "# Subset the examples belonging to the \"sexuality\" group from the second stream \n",
        "# and add a constraint on the group's false positive rate.\n",
        "context_sen_subset = context_sen.subset(lambda: groups_tensor_sen[:, 0] \u003e 0)\n",
        "constraint = [tfco.false_positive_rate(context_sen_subset) \u003c= epsilon]\n",
        "\n",
        "# Create a rate minimization problem.\n",
        "problem = tfco.RateMinimizationProblem(objective, constraint)\n",
        "\n",
        "# Set up a constrained optimizer.\n",
        "optimizer = tfco.ProxyLagrangianOptimizerV2(\n",
        "    optimizer=tf.keras.optimizers.Adam(learning_rate=hparams[\"learning_rate\"]),\n",
        "    num_constraints=problem.num_constraints)\n",
        "\n",
        "# The variables to optimize include the model weights,\n",
        "# the trainable variables of the rate minimization problem, and\n",
        "# the trainable variables of the constrained optimizer.\n",
        "var_list = (model_constrained.trainable_weights + list(problem.trainable_variables) +\n",
        "            optimizer.trainable_variables())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CoFWd8wMWs94"
      },
      "source": [
        "We are ready to train the model. Both minibatch streams are indexed by the same step counter, with each stream's indices wrapping around its own number of examples. For every gradient update, we copy the minibatch contents from the first stream into the tensors `features_tensor` and `labels_tensor`, and the minibatch contents from the second stream into the tensors `features_tensor_sen`, `labels_tensor_sen` and `groups_tensor_sen`.\n",
        "\n",
        "**Note**: The following code cell may take ~12 minutes to run."
      ]
    },
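    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The index bookkeeping in the loop below can be isolated into a small helper, which makes the wraparound behavior easier to see. This is a sketch that mirrors the loop rather than code used by it; the name `two_stream_batch_indices` is hypothetical and not part of TFCO, and it returns NumPy arrays where the loop builds lists (either works for indexing `text_train`).\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def two_stream_batch_indices(batch_index, batch_size, num_examples,\n",
        "                             protected_group_indices):\n",
        "    # Hypothetical helper mirroring the index computation in the training loop.\n",
        "    # Positions covered by this step, before wrapping around.\n",
        "    positions = np.arange(batch_index * batch_size, (batch_index + 1) * batch_size)\n",
        "    # First stream: wrap around the full training set.\n",
        "    batch_indices = positions % num_examples\n",
        "    # Second stream: wrap around the sensitive-group examples only,\n",
        "    # then map back to row indices of the training set.\n",
        "    num_examples_sen = len(protected_group_indices)\n",
        "    batch_indices_sen = protected_group_indices[positions % num_examples_sen]\n",
        "    return batch_indices, batch_indices_sen\n",
        "\n",
        "# Toy example: 10 training rows, of which rows 2, 5 and 7 are in a sensitive group.\n",
        "protected = np.array([2, 5, 7])\n",
        "first, second = two_stream_batch_indices(\n",
        "    batch_index=1, batch_size=4, num_examples=10, protected_group_indices=protected)\n",
        "print(first)   # [4 5 6 7]\n",
        "print(second)  # [5 7 2 5]\n",
        "```\n",
        "\n",
        "Because the second stream is much smaller than the first, its indices wrap around many times during one pass over the training set, so each sensitive-group example is revisited repeatedly within the epoch."
      ]
    },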
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zbXohC6vWs95"
      },
      "outputs": [],
      "source": [
        "# Indices of sensitive group members.\n",
        "protected_group_indices = np.nonzero(groups_train.sum(axis=1))[0]\n",
        "\n",
        "num_examples = text_train.shape[0]\n",
        "num_examples_sen = protected_group_indices.shape[0]\n",
        "batch_size = hparams[\"batch_size\"]\n",
        "\n",
        "# Number of steps needed for one epoch over the training sample.\n",
        "num_steps = int(num_examples / batch_size)\n",
        "\n",
        "start_time = time.time()\n",
        "\n",
        "# Loop over minibatches.\n",
        "for batch_index in range(num_steps):\n",
        "    # Indices for current minibatch in the first stream.\n",
        "    batch_indices = np.arange(\n",
        "        batch_index * batch_size, (batch_index + 1) * batch_size)\n",
        "    batch_indices = [ind % num_examples for ind in batch_indices]\n",
        "\n",
        "    # Indices for current minibatch in the second stream.\n",
        "    batch_indices_sen = np.arange(\n",
        "        batch_index * batch_size, (batch_index + 1) * batch_size)\n",
        "    batch_indices_sen = [protected_group_indices[ind % num_examples_sen]\n",
        "                         for ind in batch_indices_sen]\n",
        "\n",
        "    # Assign features, labels, groups from the minibatches to the respective tensors.\n",
        "    features_tensor.assign(text_train[batch_indices, :])\n",
        "    labels_tensor.assign(labels_train[batch_indices])\n",
        "\n",
        "    features_tensor_sen.assign(text_train[batch_indices_sen, :])\n",
        "    labels_tensor_sen.assign(labels_train[batch_indices_sen])\n",
        "    groups_tensor_sen.assign(groups_train[batch_indices_sen, :])\n",
        "\n",
        "    # Gradient update.\n",
        "    optimizer.minimize(problem, var_list=var_list)\n",
        "    \n",
        "    # Print batch training stats every 10 steps, and on the first and last steps.\n",
        "    if (batch_index + 1) % 10 == 0 or batch_index in (0, num_steps - 1):\n",
        "        hinge_loss = problem.objective()\n",
        "        max_violation = max(problem.constraints())\n",
        "\n",
        "        elapsed_time = time.time() - start_time\n",
        "        sys.stdout.write(\n",
        "            \"\\rStep %d / %d: Elapsed time = %ds, Loss = %.3f, Violation = %.3f\" %\n",
        "            (batch_index + 1, num_steps, elapsed_time, hinge_loss, max_violation))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DdJfplDpWs97"
      },
      "source": [
        "Having trained the constrained model, we plot various evaluation metrics for the model on the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jEerPEwLhfTN"
      },
      "outputs": [],
      "source": [
        "scores_constrained_test = model_constrained.predict(text_test)\n",
        "eval_result_constrained = evaluate_results(\n",
        "    labels_test, scores_constrained_test, groups_test, group_names)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ustp5z7xQnHI"
      },
      "source": [
        "As before, remember to select `false_positive_rate`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ztK7iM4LjKmT"
      },
      "outputs": [],
      "source": [
        "plot_fairness_indicators(eval_result_constrained, \"Constrained\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6P6dxSg5_mTu"
      },
      "outputs": [],
      "source": [
        "multi_results = {\n",
        "    'constrained': eval_result_constrained,\n",
        "    'unconstrained': eval_result_unconstrained,\n",
        "}\n",
        "plot_multi_fairness_indicators(multi_eval_results=multi_results)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EfKo5O3QWs9-"
      },
      "source": [
        "As the Fairness Indicators show, the constrained model yields significantly lower false positive rates on sexuality-related comments than the unconstrained model, with only a slight dip in overall accuracy."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Fairness Indicators TFCO Wiki Comments Case Study.ipynb",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
